import os
import sys
import torch
import numpy as np
import wandb
import datetime
import logging
import provider
import shutil
import torch.nn as nn
import argparse

from pathlib import Path
from tqdm import tqdm
from data_utils.ModeNet40CDataloader import ModelNet40C
from data_utils.ScanObjectnnDataloader import ScanObjectNNSVM
from util import IOStream, AverageMeter,cal_loss
from models.dgcnn import DGCNN
# Absolute directory of this script; it is also treated as the project root.
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Expose ./models on the import path. This runs after the imports above, so
# it only affects imports performed later.
sys.path.append(os.path.join(ROOT_DIR, 'models'))

def parse_args(argv=None):
    '''PARAMETERS

    Build the command-line parser for fine-tuning and parse the arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding every option declared below.
    '''
    def _str2bool(value):
        """Parse a boolean flag value.

        ``type=bool`` is a trap: ``bool('False')`` is True. This accepts the
        usual spellings and keeps ``''`` -> False as ``bool('')`` did.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='dgcnn', help='model name ')
    parser.add_argument('--dataset', default='ModelNet40C', help='[ModelNet40C,ScanObjectNN]')
    # Left unset here; main() derives it from --dataset.
    parser.add_argument('--num_category',  type=int)
    parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--pretrain', type=_str2bool, default=False, help='whether to use a pretrained model')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
    parser.add_argument('--exp_name', type=str, default='dgcnn_shapenet_supervise', help='the name of experiment ')
    parser.add_argument('--pretrain_path', type=str, default='./checkpoints/dgcnn_shapenet_contrastive_less_SVM/models/best_model.pth', help='the path of pretrain model')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout rate')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--k', type=int, default=20, metavar='N',
                        help='Num of nearest neighbors to use')
    # Seeds are integers (torch.manual_seed casts to int anyway).
    parser.add_argument('--seed', type=int, default=42, help='the seed of random')
    return parser.parse_args(argv)


def inplace_relu(m):
    """Put a ReLU-family module into in-place mode.

    Intended for ``module.apply(inplace_relu)``. Any module whose class name
    contains ``'ReLU'`` gets ``inplace = True``. All other modules are left
    untouched.
    """
    if 'ReLU' in type(m).__name__:
        m.inplace = True


def test(model, loader, num_class=40):
    """Evaluate a classifier and return ``(instance_acc, class_acc)``.

    Args:
        model: classifier module. It is switched to eval mode. Each batch is
            moved to the device of the model's first parameter, or to the CPU
            if the model has no parameters.
        loader: iterable of ``(points, target)`` batches that supports
            ``len()``. Dims 1 and 2 of ``points`` are swapped before the
            forward pass. ``target`` holds integer class ids and is flattened
            to 1-D.
        num_class: number of classes. Every target id must be in
            ``[0, num_class)``.

    Returns:
        instance_acc: correct predictions / total samples over the whole loader.
        class_acc: mean per-class accuracy (correct / seen), averaged over the
            classes that actually occur in ``loader``.
        Both values are 0.0 for an empty loader.

    The counts are accumulated per sample rather than averaged per batch.
    This means the result does not depend on batch size, on a short final
    batch, or on the order in which the loader yields samples.
    """
    classifier = model.eval()
    # Follow the model's device instead of reading a module-global flag.
    first_param = next(classifier.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_correct = 0
    total_seen = 0
    class_correct = np.zeros(num_class)
    class_seen = np.zeros(num_class)

    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            points = points.to(device).transpose(2, 1)
            # Flatten so a (B, 1) label tensor cannot broadcast against (B,) predictions.
            target = target.to(device).long().view(-1)

            pred_choice = classifier(points).argmax(dim=1)
            hits = pred_choice.eq(target)

            total_correct += int(hits.sum().item())
            total_seen += target.numel()
            # minlength=num_class keeps the per-batch tallies aligned with the accumulators.
            class_seen += torch.bincount(target, minlength=num_class).cpu().numpy()
            class_correct += torch.bincount(target[hits], minlength=num_class).cpu().numpy()

    instance_acc = total_correct / total_seen if total_seen else 0.0
    # Only average over classes that were seen; unseen ones would be 0/0 = NaN.
    present = class_seen > 0
    if present.any():
        class_acc = float(np.mean(class_correct[present] / class_seen[present]))
    else:
        class_acc = 0.0
    return instance_acc, class_acc


def _make_experiment_dirs(args):
    """Create ./log/classification/<run>/{checkpoints,logs} and return them.

    ``<run>`` is ``args.log_dir`` when given. Otherwise it is the current local
    time formatted as ``%Y-%m-%d_%H-%M``.

    Returns:
        (exp_dir, checkpoints_dir, log_dir) as ``pathlib.Path`` objects.
    """
    timestr = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    run_name = timestr if args.log_dir is None else args.log_dir
    exp_dir = Path('./log/').joinpath('classification', run_name)
    exp_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs')
    log_dir.mkdir(exist_ok=True)
    return exp_dir, checkpoints_dir, log_dir


def _make_logger(log_dir, model_name):
    """Return the "Model" logger, adding an INFO file handler at <log_dir>/<model_name>.txt."""
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, model_name))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)
    return logger


def _build_dataloaders(args):
    """Return ``(train_loader, test_loader, num_category)`` for ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` is neither 'ModelNet40C' nor 'ScanObjectNN'.
    """
    if args.dataset == 'ModelNet40C':
        print('ModelNet40C')
        # NOTE(review): 1024 looks like the point count, but --num_point is not
        # wired through here. Confirm the dataset signature before changing it.
        train_set = ModelNet40C(1024, "lidar", 1, 'train')
        test_set = ModelNet40C(1024, "lidar", 1, 'test')
        num_category = 40
    elif args.dataset == 'ScanObjectNN':
        print('ScanObjectNN')
        train_set = ScanObjectNNSVM(1024, 'train')
        test_set = ScanObjectNNSVM(1024, 'test')
        num_category = 15
    else:
        raise ValueError(
            "Unknown --dataset %r; expected 'ModelNet40C' or 'ScanObjectNN'" % args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    # Evaluation does not need shuffling; keep it deterministic.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
    return train_loader, test_loader, num_category


def main(args):
    """Fine-tune a pretrained DGCNN on ``args.dataset`` and keep the best checkpoint.

    Steps:
      1. Seed torch and create the ./log/classification/<run>/ directories.
      2. Set up a file logger and build the train/test loaders.
      3. Build DGCNN and load ``args.pretrain_path`` non-strictly, then replace
         ``linear3`` with an ``args.num_category``-way head.
      4. Train for ``args.epoch`` epochs, evaluating after each one.

    The state dict is saved to
    ``checkpoints/<exp_name>/models/best_model_finute.pth`` whenever test
    instance accuracy ties or beats the best so far. Per-epoch metrics go to
    wandb.

    Args:
        args: parsed command-line namespace (see ``parse_args``). It is used
            as given. ``args.num_category`` is overwritten from the dataset.
    """
    # Must be set before anything initialises CUDA for it to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    '''CREATE DIR'''
    _, _, log_dir = _make_experiment_dirs(args)

    '''LOG'''
    logger = _make_logger(log_dir, args.model)

    def log_string(msg):
        logger.info(msg)
        print(msg)

    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')
    train_loader, test_loader, num_category = _build_dataloaders(args)
    args.num_category = num_category
    num_class = num_category

    '''MODEL LOADING'''
    wandb.init(project="dgcnn", name=args.exp_name+args.dataset)
    # Built with a 40-way head, presumably so the pretrained checkpoint's head
    # shapes match: load_state_dict raises on shape mismatches even with
    # strict=False. The head is replaced below.
    classifier = DGCNN(args, output_channels=40)
    criterion = cal_loss
    classifier.apply(inplace_relu)

    '''PRETRAIN MODEL LOADING'''
    # map_location lets a GPU-saved checkpoint load under --use_cpu.
    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    state_dict = torch.load(args.pretrain_path, map_location='cpu')
    classifier.load_state_dict(state_dict, strict=False)
    classifier.linear3 = nn.Linear(256, args.num_category)
    log_string('Load pretrained model succeessfully')

    if not args.use_cpu:
        classifier = classifier.cuda()
    # Watch after the head swap so the new linear3 is included.
    wandb.watch(classifier)

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # NOTE(review): SGD uses a fixed lr=0.01 and ignores --learning_rate.
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    best_instance_acc = 0.0
    best_class_acc = 0.0

    # Create the checkpoint directory up front; torch.save does not create parents.
    save_dir = os.path.join('checkpoints', args.exp_name, 'models')
    os.makedirs(save_dir, exist_ok=True)
    save_file = os.path.join(save_dir, 'best_model_finute.pth')

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(args.epoch):
        log_string('Epoch %d (%d/%s):' % (epoch + 1, epoch + 1, args.epoch))
        classifier = classifier.train()
        # Fresh meter each epoch so 'Train Loss' is this epoch's average.
        train_loss = AverageMeter()
        mean_correct = []

        for points, target in tqdm(train_loader, total=len(train_loader), smoothing=0.9):
            optimizer.zero_grad()
            batch_size = points.size(0)

            # Augment on the numpy copy; only the first three channels are scaled/shifted.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points).transpose(2, 1)
            # Flatten so a (B, 1) label tensor cannot broadcast against (B,) predictions.
            target = target.long().view(-1)

            if not args.use_cpu:
                points, target = points.cuda(), target.cuda()

            logits = classifier(points)
            loss = criterion(logits, target)
            loss.backward()
            optimizer.step()

            pred_choice = logits.detach().argmax(dim=1)
            mean_correct.append(pred_choice.eq(target).sum().item() / float(batch_size))
            train_loss.update(loss.item(), batch_size)

        # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer.step() calls.
        scheduler.step()

        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)
        wandb_log = {'Train Loss': train_loss.avg, 'Train Accuracy': train_instance_acc}

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), test_loader, num_class=num_class)

        is_best = instance_acc >= best_instance_acc
        if is_best:
            best_instance_acc = instance_acc
        if class_acc >= best_class_acc:
            best_class_acc = class_acc
        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
        log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

        if is_best:
            logger.info('Save model...')
            log_string('Saving at %s' % save_file)
            torch.save(classifier.state_dict(), save_file)

        wandb_log['instance Accuracy'] = instance_acc
        wandb_log['class Accuracy'] = class_acc
        wandb.log(wandb_log)

    logger.info('End of training...')


if __name__ == '__main__':
    # Script entry point: parse the CLI and run fine-tuning.
    # `args` is bound at module level on purpose: other functions in this
    # file (e.g. test()) may read it as a global instead of taking it as a
    # parameter, so keep this assignment at module scope.
    args = parse_args()
    main(args)
